from typing import List, Optional
import fire
from llama import Llama, Dialog
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from tools import read_jsonl, DynamicDataset, collate_fn
import jsonlines
from tqdm import tqdm


def hyperparameters():
    """Parse the command-line options for this script.

    Options are read from ``sys.argv`` by ``ArgumentParser.parse_args``.

    Returns:
        The parsed namespace with attributes ``model``, ``model_name``,
        ``data_dir``, ``data_file``, ``output_dir``, ``model_size`` (all
        strings), ``batch_size``, ``max_seq_len``, ``max_gen_len`` (ints)
        and ``temperature`` (float).
    """
    cli = ArgumentParser(description="Run ORCA2 for Rule Probing")

    # One row per option: (flag, type, default, help). Rows are registered in
    # this order, which is also the order shown by ``--help``.
    option_specs = (
        ("--model", str, "", "Model path"),
        ("--model_name", str, "13B-Chat", None),
        ("--data_dir", str, "ecare", None),
        ("--data_file", str, "all", None),
        ("--output_dir", str, "output", None),
        ("--model_size", str, "13B", None),
        ("--batch_size", int, 12, None),
        ("--max_seq_len", int, 512, None),
        ("--max_gen_len", int, 10, None),
        ("--temperature", float, 0.00001, None),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )

    return cli.parse_args()

    
if __name__ == "__main__":
    # Script entry point: for every example in the input JSONL file, ask the
    # model whether the example's general rule can be used to answer the
    # premise/hypotheses question, and write the model's reply next to the
    # example's fields in an output JSONL file.
    import os  # local import: only needed to prepare the output directory

    args = hyperparameters()

    print(args)
    print(f"Model: {args.model_name}")

    # NOTE(review): model_parallel_size is hard-coded to 2 — presumably it has
    # to match how the checkpoint under ckpt_dir is sharded; confirm before
    # pointing --model_name at a different checkpoint.
    generator = Llama.build(
        ckpt_dir=f"{args.model}/{args.model_name}",
        tokenizer_path=f"{args.model}/tokenizer.model",
        max_seq_len=args.max_seq_len,
        max_batch_size=args.batch_size,
        model_parallel_size=2,
    )

    data = read_jsonl(f"./data/{args.data_dir}/{args.data_file}_full.jsonl")
    dataset = DynamicDataset(*data)
    # shuffle=False keeps the output records in dataset order.
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    system_message = "You are an AI language model. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."

    # Create the output directory up front so a missing directory does not
    # abort the run at the point the results file is opened.
    os.makedirs(f"./{args.output_dir}", exist_ok=True)
    output_path = f"./{args.output_dir}/{args.model_name}_v2.jsonl"

    # The context manager guarantees the writer is flushed and closed, even if
    # generation raises part-way through the dataset.
    # NOTE(review): if this script is launched with one process per
    # model-parallel rank, every rank runs this loop and opens the same output
    # file — confirm that is intended.
    with jsonlines.open(output_path, "w") as fo:
        for batch in tqdm(dataloader):
            premises, hypotheses1, hypotheses2, rules, labels = batch

            # One two-turn dialog (system + user) per example in the batch.
            # The prompt text is part of the experiment; it is kept verbatim.
            dialogs: List[Dialog] = [
                [
                    {"role": "system", "content": system_message},
                    {
                        "role": "user",
                        "content": (
                            f"Question: {p} Hypothesis1 or Hypothesis2?\n"
                            f"Hypothesis1: {h1}\n"
                            f"Hypothesis2: {h2} \n"
                            f"Do you think \"{r}\" can be used to answer this question? "
                            "You answer should follow the format like \"Answer: Yes or No.\""
                        ),
                    },
                ]
                for p, h1, h2, r in zip(premises, hypotheses1, hypotheses2, rules)
            ]
            # Alternative user prompt (asks which hypothesis is more plausible
            # instead of probing the rule):
            #   f"{p} Hypothesis1 or Hypothesis2?\nHypothesis1: {h1}\nHypothesis2: {h2} \nYou answer should follow the format like \"Answer: Hypothesis(1 or 2) is more plausible.\nExplanation: ___\""

            results = generator.chat_completion(
                dialogs,
                max_gen_len=args.max_gen_len,
                temperature=args.temperature,
                # do_sample=True,
                # top_k=50
            )

            # labels.tolist() converts the batch's labels to plain Python
            # values before they are serialized (presumably labels is a tensor
            # produced by the DataLoader's collation — confirm).
            for result, p, h1, h2, r, l in zip(
                results, premises, hypotheses1, hypotheses2, rules, labels.tolist()
            ):
                fo.write(
                    {
                        "premise": p,
                        "hypothesis1": h1,
                        "hypothesis2": h2,
                        "general_rule": r,
                        "label": l,
                        "answer": result['generation']['content'],
                    }
                )
